home *** CD-ROM | disk | FTP | other *** search
/ HPAVC / HPAVC CD-ROM.iso / SNNSV32.ZIP / SNNSv3.2 / kernel / sources / kr_inversion.c < prev    next >
C/C++ Source or Header  |  1994-04-25  |  8KB  |  277 lines

  1. /**********************************************************************
  2.   FILE           : @(#)kr_inversion.c    1.11
  3.  
  4.   SHORTNAME      : 
  5.   SNNS VERSION   : 3.2
  6.  
  7.   PURPOSE        : Implement Kindermann/Linden-inversion-method
  8.  
  9.   NOTES          : The functions implemented here closely resemble the
  10.                    the functions propagateNetForward and propagateNetBackward2 
  11.            in the file learn_f.c.
  12.   FUNCTIONS      : -- kr_initInversion
  13.                       Purpose : initialize net for inversion algorithm
  14.                    Calls   : int kr_topoCheck();
  15.                          int kr_IOCheck();
  16.                           int kr_topoSort();
  17.  
  18.                    -- kr_inv_forwardPass
  19.                    Purpose : topological forward propagation
  20.                   Calls   : nothing
  21.  
  22.                -- kr_inv_backwardPass
  23.                       Purpose : Backward error propagation (topological) 
  24.                   Calls   : nothing
  25.  
  26.   AUTHOR         : Guenter Mamier
  27.   DATE           : 04.02.92
  28.  
  29.   CHANGED BY     : Sven Doering, Michael Vogt
  30.   IDENTIFICATION : @(#)kr_inversion.c    1.11 3/15/94
  31.   SCCS VERSION   : 1.11
  32.   LAST CHANGE    : 3/15/94
  33.  
  34.              Copyright (c) 1990-1994  SNNS Group, IPVR, Univ. Stuttgart, FRG
  35.  
  36. **********************************************************************/
  37. #include <stdio.h>
  38. #include <math.h>
  39. #include <values.h>
  40.  
  41. #include "kr_typ.h"    
  42. #include "kr_const.h"
  43. #include "kernel.h"
  44. #include "kr_def.h"    
  45. #include "kr_mac.h"    
  46. #include "kr_inversion.ph"
  47.  
  48.  
  49.  
  50. /*****************************************************************************
  51.   FUNCTION : kr_initInversion
  52.  
  53.   PURPOSE  : initialize net for inversion algorithm
  54.   NOTES    :
  55.   UPDATE   : 06.02.92
  56. ******************************************************************************/
  57. int kr_initInversion(void)
  58. {
  59.     int ret_code = KRERR_NO_ERROR;
  60.  
  61.     if (NetModified || (TopoSortID != TOPOLOGICAL_FF &&
  62.             TopoSortID != TOPOLOGIC_LOGICAL)){
  63.        /*  Net has been modified or topologic array isn't initialized */
  64.        /*  check the topology of the network  */
  65.       ret_code = kr_topoCheck();
  66.     if (ret_code < KRERR_NO_ERROR)  
  67.        return( ret_code );  /*  an error has occured  */
  68.     if (ret_code < 2)  
  69.        return( KRERR_NET_DEPTH );  /*  the network has less then 2 layers  */
  70.  
  71.  
  72.     /*  count the no. of I/O units and check the patterns  */
  73.     ret_code = kr_IOCheck();
  74.     if(ret_code < KRERR_NO_ERROR)  
  75.        return( ret_code );
  76.  
  77.     /*  sort units by topology and by topologic type  */
  78.  
  79.     ret_code = kr_topoSort( TOPOLOGICAL_FF );
  80.   }
  81.   return(ret_code);
  82. }
  83.  
  84.  
  85.  
  86. /*****************************************************************************
  87.   FUNCTION : kr_inv_forwardPass
  88.  
  89.   PURPOSE  : topological forward propagation
  90.   NOTES    :
  91.   UPDATE   : 29.01.92
  92. ******************************************************************************/
  93. void  kr_inv_forwardPass(struct UnitList *inputs)
  94. {
  95.  
  96.    register struct Unit   *unit_ptr;
  97.    register TopoPtrArray  topo_ptr;    /* points to a topological sorted    */
  98.                        /* unit stucture (input units first) */
  99.    struct UnitList        *IUnit;      /* working list of input units       */
  100.  
  101.  
  102.    /* initialize the topological pointer */
  103.  
  104.    topo_ptr = topo_ptr_array;
  105.  
  106.  
  107.    /*  calculate the activation and output value of the input units */ 
  108.  
  109.    IUnit = inputs;
  110.    while((unit_ptr = *++topo_ptr) != NULL){
  111.  
  112.      /*  clear error values  */
  113.      unit_ptr->Aux.flint_no = 0.0;
  114.  
  115.      if(unit_ptr->out_func == OUT_IDENTITY)
  116.         unit_ptr->Out.output = unit_ptr->act = IUnit->act;
  117.      else  /* no identity output function: calculate unit's output also  */
  118.        unit_ptr->Out.output = (*unit_ptr->out_func)(unit_ptr->act = IUnit->act);
  119.      IUnit = IUnit->next;
  120.    }
  121.  
  122.  
  123.    /*  popagate hidden units  */
  124.  
  125.    while((unit_ptr = *++topo_ptr) != NULL){
  126.  
  127.      /*  clear error values  */
  128.      unit_ptr->Aux.flint_no = 0.0;
  129.  
  130.      /*  calculate the activation value of the unit: 
  131.      call the activation function if needed  */
  132.      unit_ptr->act = (*unit_ptr->act_func) (unit_ptr);
  133.  
  134.      if(unit_ptr->out_func == OUT_IDENTITY)
  135.        unit_ptr->Out.output = unit_ptr->act;
  136.      else
  137.        /*  no identity output function: calculate unit's output also  */
  138.        unit_ptr->Out.output = (*unit_ptr->out_func) (unit_ptr->act);
  139.    }
  140.  
  141.  
  142.    /*  popagate output units  */
  143.  
  144.    while((unit_ptr = *++topo_ptr) != NULL){
  145.  
  146.      /*  clear error values  */
  147.      unit_ptr->Aux.flint_no = 0.0;
  148.  
  149.      /*  calculate the activation value of the unit: 
  150.      call the activation function if needed  */
  151.      unit_ptr->act = (*unit_ptr->act_func) (unit_ptr);
  152.  
  153.      if(unit_ptr->out_func == OUT_IDENTITY)
  154.        unit_ptr->Out.output = unit_ptr->act;
  155.      else      /*  no identity output function: calculate unit's output also  */
  156.        unit_ptr->Out.output = (*unit_ptr->out_func) (unit_ptr->act);
  157.    }
  158. }
  159.  
  160.  
  161.  
  162. /*****************************************************************************
  163.   FUNCTION : kr_inv_backwardPass
  164.  
  165.   PURPOSE  : Backward error propagation (topological) 
  166.   NOTES    :
  167.   UPDATE   : 04.02.92
  168. *****************************************************************************/
  169. double kr_inv_backwardPass(float learn, float delta_max, int *err_units, 
  170.                float ratio, struct UnitList *inputs, 
  171.                struct UnitList *outputs)
  172. {
  173.    register struct Link   *link_ptr;
  174.    register struct Site   *site_ptr;
  175.    register struct Unit   *unit_ptr;
  176.    register float         error,  sum_error,  eta,  devit;
  177.    register TopoPtrArray  topo_ptr;
  178.    struct UnitList        *IUnit, *OUnit;
  179.  
  180.  
  181.    sum_error = 0.0;    /*  reset network error  */
  182.    *err_units = 0;     /*  reset error units */
  183.    eta = learn;        /*  store learn_parameter in CPU register  */
  184.  
  185.  
  186.    /* add 3 to no_of_topo_units because topologic array contains 4 NULL 
  187.       pointers  */
  188.  
  189.    topo_ptr = topo_ptr_array + (no_of_topo_units + 3);
  190.  
  191.  
  192.    /*  calculate output units only  */
  193.  
  194.    OUnit = outputs;
  195.    while(OUnit->next != NULL)OUnit = OUnit->next;
  196.    while((unit_ptr = *--topo_ptr) != NULL){
  197.  
  198.  
  199.      /*  calc. devitation */
  200.      devit = OUnit->i_act - unit_ptr->Out.output;
  201.      OUnit->act = unit_ptr->Out.output;
  202.      OUnit = OUnit->prev;
  203.      if ( (devit > -delta_max) && (devit < delta_max) ){
  204.        continue;
  205.      }else{
  206.        *err_units += 1;
  207.      }
  208.  
  209.      /*  sum up the error of the network  */
  210.      sum_error += devit * devit;  
  211.  
  212.      /*    calc. error for output units     */
  213.      error = devit * (unit_ptr->act_deriv_func) ( unit_ptr );
  214.      /*     error = devit;*/
  215.  
  216.      /* Calculate sum of errors of predecessor units  */
  217.      if(UNIT_HAS_DIRECT_INPUTS( unit_ptr )){
  218.        FOR_ALL_LINKS( unit_ptr, link_ptr )
  219.          link_ptr->to->Aux.flint_no += link_ptr->weight * error;
  220.      }else{        /*    the unit has sites  */
  221.        FOR_ALL_SITES_AND_LINKS( unit_ptr, site_ptr, link_ptr )
  222.          link_ptr->to->Aux.flint_no += link_ptr->weight * error;
  223.      }
  224.    }
  225.  
  226.  
  227.    /*  calculate hidden units only  */
  228.  
  229.    while((unit_ptr = *--topo_ptr) != NULL){
  230.  
  231.      /*    calc. the error of the (hidden) unit  */
  232.      error = (unit_ptr->act_deriv_func) ( unit_ptr ) * unit_ptr->Aux.flint_no;
  233.      error = unit_ptr->Aux.flint_no;
  234.  
  235.      /* Calculate sum of errors of predecessor units  */
  236.      if(UNIT_HAS_DIRECT_INPUTS( unit_ptr )){
  237.        FOR_ALL_LINKS( unit_ptr, link_ptr )
  238.       link_ptr->to->Aux.flint_no += link_ptr->weight * error;
  239.      }else{       /*  the unit has sites  */
  240.        FOR_ALL_SITES_AND_LINKS( unit_ptr, site_ptr, link_ptr )
  241.         link_ptr->to->Aux.flint_no += link_ptr->weight * error;
  242.      }
  243.      unit_ptr->act = unit_ptr->i_act;
  244.    }
  245.  
  246.  
  247.    /*  calculate input units only  */
  248.  
  249.    IUnit = inputs;
  250.    while(IUnit->next != NULL)IUnit = IUnit->next;
  251.    while((unit_ptr = *--topo_ptr) != NULL){
  252.  
  253.      /*    calc. the error of the (input) unit  */
  254.      error = (unit_ptr->act_deriv_func) ( unit_ptr ) * unit_ptr->Aux.flint_no;
  255.      error = unit_ptr->Aux.flint_no;
  256.  
  257.      /* Calculate the new activation for the input units */
  258.      IUnit->im_act += eta * error + ratio*(IUnit->i_act - (float)unit_ptr->act);
  259.      unit_ptr->act = 1.0 / (1.0 + exp((double)(-IUnit->im_act)));
  260.      IUnit->act = unit_ptr->act;
  261.      IUnit = IUnit->prev;
  262.    }
  263.  
  264.  
  265.    /*  return the error of the network */
  266.  
  267.    sum_error *= 0.5;
  268.    return( sum_error ); 
  269.  
  270.  
  271. }
  272.  
  273.  
  274.  
  275.  
  276.  
  277.